#!/usr/bin/python3
#################################################################################
# IT IS RECOMMENDED NOT TO EDIT THIS SCRIPT EXCEPT IN THE DESIGNATED LOCATIONS. #
#################################################################################

###################################
# CHECK FOR REQUIRED DEPENDENCIES #
###################################
import sys
# Fail fast with a readable message on interpreters older than Python 3.8.
if sys.version_info < (3, 8):
    print("This script requires Python 3.8 or greater.")
    sys.exit(1)

# Import the third-party dependencies up front (aiohttp and click are used by module name further down) so that a
#  missing library produces an installation hint instead of a bare ImportError traceback.
# NOTE(review): these messages are printed to stdout; stderr may be preferable when the output is being piped.
try:
    import aiohttp
    import click
    import dataclasses_json
    import rich
except ImportError:
    print("The following libraries are required: aiohttp, click, dataclasses-json, rich")
    print()
    print("These libraries can be installed using the following command:")
    print()
    print("\t/path/to/pip install aiohttp click dataclasses-json rich")
    print()
    print("Note: It is recommended to use a virtual environment if possible.")
    sys.exit(1)


###########
# IMPORTS #
###########
from typing import Optional, List, Dict, Set, Any, Union

from dataclasses import dataclass, field
from dataclasses_json import dataclass_json, LetterCase, DataClassJsonMixin, config
from json import dumps as json_dumps
from rich.console import Console
from rich.table import Table
from traceback import StackSummary

import asyncio
import base64
import datetime
import os
import traceback


###########################
# API ENDPOINT DEFINITION #
###########################
@dataclass
class APIEndpoint:
    """
    An API endpoint (SBC) to connect to.

    Requests are sent to ``https://<host>/api/...`` using the given API user's credentials.

    :param host: the IP / hostname of the SBC
    :param username: the username of the API user to use
    :param password: the password of the API user to use
    """
    host: str
    username: str
    # repr=False keeps the password out of the generated __repr__, so printing an endpoint does not reveal it.
    password: str = field(repr=False)


############################################
# SCRIPT CONFIGURATION - EDIT THIS SECTION #
############################################
# The API endpoints (one per SBC) to fetch telemetry data from. Each SBC listed here must have an API user configured
#  on it, and that user's username / password must be supplied in its APIEndpoint entry. While this list is empty the
#  script fetches nothing and reports that no telemetry records were found.
TELEMETRY_ENDPOINTS: List[APIEndpoint] = [
    # APIEndpoint(host="127.0.0.1", username="user", password="password"),
]


###################
# SCRIPT CONTENTS #
###################
@dataclass
class APIError:
    """
    The ``error`` object of an SBC API error response (members ``code``, ``status``, ``message`` and ``details``).

    This has to be a dataclass: dataclasses_json passes values through unchanged for nested field types it does not
    recognise (such as a plain class), which left ``APIErrorResponse.error`` holding a raw dict instead of an APIError.
    Every member defaults to ``None`` so that an error body missing one of them is still decoded, rather than raising
    while an error is being reported.
    """
    code: Optional[int] = None
    status: Optional[str] = None
    message: Optional[str] = None
    details: Optional[Any] = None


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class APIErrorResponse(DataClassJsonMixin):
    """
    An error body returned by an SBC API: a JSON object with an ``error`` member.

    Built via ``APIErrorResponse.from_dict`` by the fetch functions whenever a response contains an ``"error"`` key.
    """
    # The error details, declared as APIError (code / status / message / details).
    # NOTE(review): dataclasses_json only decodes this into an APIError instance if APIError is a dataclass; otherwise
    #  the raw dict is stored here — confirm against the APIError definition.
    error: APIError


class APIException(Exception):
    """
    Describes a failed interaction with an SBC API.

    The fetch functions return these rather than raising them, so ``__traceback__`` is normally ``None``; the optional
    ``cause_stack`` records where the failure was detected instead.

    :param message: human-readable description of the failure
    :param target_endpoint_host: the host of the SBC that was being contacted
    :param response: the parsed error body returned by the API, if any; only its ``error`` member is kept
    :param cause: the underlying exception, if any
    :param cause_stack: the stack captured when the failure was detected, if any
    """

    def __init__(self,
                 message: str,
                 target_endpoint_host: str,
                 response: Optional["APIErrorResponse"] = None,
                 cause: Optional[Exception] = None,
                 cause_stack: Optional[StackSummary] = None):
        # Initialise the base Exception with the message; without this, args is empty and str(),
        #  repr() and traceback.format_exception() show no message at all.
        super().__init__(message)
        self.message = message
        self.target_endpoint_host = target_endpoint_host
        self.response = response.error if response is not None else None
        self.cause = cause
        self.cause_stack = cause_stack

    def __str__(self) -> str:
        """Return the message together with the target host and, when present, the API's error details."""
        text = f"{self.message} (host: {self.target_endpoint_host})"
        if self.response is not None:
            text += f" API error: {self.response}"
        return text


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class UsageTelemetryDataInboundOutboundData(DataClassJsonMixin):
    """
    Call totals for one direction (the ``inboundData`` or ``outboundData`` member) of a usage telemetry record.

    JSON keys are camelCase: ``callCount`` and ``totalDurationInMinutes``.
    """
    call_count: int
    # Summed per direction when records are merged; divided by call_count to get the average call duration.
    total_duration_in_minutes: int


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class UsageTelemetryData(DataClassJsonMixin):
    """
    One usage record from the ``usageData`` list of the telemetry API response.

    Records are merged by (partner_id, customer_id, call_type) in merge_telemetry_data. JSON keys are the camelCase
    forms of the field names (e.g. ``summaryDate``, ``userUriList``).
    """
    # Dates are kept as the strings sent by the API; they are not parsed.
    summary_date: str
    node_id: int
    partner_id: str
    customer_id: str
    call_type: str
    start_date: str
    end_date: str
    inbound_data: UsageTelemetryDataInboundOutboundData
    outbound_data: UsageTelemetryDataInboundOutboundData
    # User URIs in this record; deduplicated into a set when records are merged.
    user_uri_list: List[str]


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class UsagePerformanceDataJoinTransferReasonData(DataClassJsonMixin):
    """
    A failure count for a single reason code within join or transfer performance data.

    JSON keys: ``reasonCode`` and ``count``.
    """
    reason_code: int
    count: int


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class UsagePerformanceDataJoinTransferData(DataClassJsonMixin):
    """
    Totals for joined or transferred calls (the ``joinData`` / ``transferData`` member of a performance record).

    The success percentage is later derived as success_count / total_call_count.
    """
    total_call_count: int
    success_count: int
    # Failure counts broken down by reason code.
    failures: List[UsagePerformanceDataJoinTransferReasonData]


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class PerformanceTelemetryData(DataClassJsonMixin):
    """
    One performance record from the ``performanceData`` list of the telemetry API response.

    Shares its identifying fields with UsageTelemetryData; merge_telemetry_data combines both kinds of record for the
    same (partner_id, customer_id, call_type) into a single merged record.
    """
    # Dates are kept as the strings sent by the API; they are not parsed.
    summary_date: str
    node_id: int
    partner_id: str
    customer_id: str
    call_type: str
    start_date: str
    end_date: str
    join_data: UsagePerformanceDataJoinTransferData
    transfer_data: UsagePerformanceDataJoinTransferData


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class TelemetryData(DataClassJsonMixin):
    """
    The body of a successful telemetry API response, as parsed by fetch_from_sbc.
    """
    # Explicit field_name overrides map these attributes to the "usageData" / "performanceData" JSON keys, which the
    #  class-wide camelCase conversion alone would not produce ("usage" / "performance").
    usage: List[UsageTelemetryData] = field(metadata=config(field_name="usageData"))
    performance: List[PerformanceTelemetryData] = field(metadata=config(field_name="performanceData"))


@dataclass_json(letter_case=LetterCase.CAMEL)
@dataclass
class MergedTelemetryData(DataClassJsonMixin):
    """
    Telemetry totals for one (partner, customer, call type) combination, accumulated across SBCs and records.

    The counters are summed by merge_telemetry_data; the derived fields at the end are populated by finish().
    """
    # Identifying / reporting-period fields, copied from the first record merged for this combination.
    summary_date: str
    partner_id: str
    customer_id: str
    call_type: str
    start_date: str
    end_date: str

    # Usage totals.
    inbound_call_count: int = 0
    inbound_total_duration_minutes: int = 0
    outbound_call_count: int = 0
    outbound_total_duration_minutes: int = 0
    user_uri_list: Set[str] = field(default_factory=set)

    # Performance totals; the failure maps are keyed by reason code.
    join_total_call_count: int = 0
    join_success_count: int = 0
    join_failures: Dict[int, int] = field(default_factory=dict)
    transfer_total_call_count: int = 0
    transfer_success_count: int = 0
    transfer_failures: Dict[int, int] = field(default_factory=dict)

    # Derived values, filled in by finish().
    inbound_average_call_count: float = 0.0
    outbound_average_call_count: float = 0.0
    total_unique_users: int = 0
    join_success_pct: float = 0.0
    transfer_success_pct: float = 0.0

    def finish(self):
        """
        Compute the derived statistics from the accumulated totals.

        Each ratio is only assigned when both of its operands are positive; otherwise the field keeps its current value.
        """
        # Despite the "*_average_call_count" names, these hold the average call duration in minutes.
        if min(self.inbound_call_count, self.inbound_total_duration_minutes) > 0:
            self.inbound_average_call_count = self.inbound_total_duration_minutes / self.inbound_call_count
        if min(self.outbound_call_count, self.outbound_total_duration_minutes) > 0:
            self.outbound_average_call_count = self.outbound_total_duration_minutes / self.outbound_call_count

        self.total_unique_users = len(self.user_uri_list)

        # Success ratios are stored as fractions (success / total), not multiplied by 100.
        if min(self.join_total_call_count, self.join_success_count) > 0:
            self.join_success_pct = self.join_success_count / self.join_total_call_count
        if min(self.transfer_total_call_count, self.transfer_success_count) > 0:
            self.transfer_success_pct = self.transfer_success_count / self.transfer_total_call_count


def __encode_credentials(username: str, password: str) -> str:
    """
    Build the Base64 payload of an HTTP Basic ``Authorization`` header.

    :param username: the API user's name
    :param password: the API user's password
    :return: ``username:password`` encoded as UTF-8 and then as standard Base64, returned as a ``str``
    """
    credential_pair = "{}:{}".format(username, password)
    encoded_bytes = base64.standard_b64encode(credential_pair.encode("UTF-8"))
    return encoded_bytes.decode("UTF-8")


def __format_iso_8601(utc_date: datetime.datetime) -> str:
    """
    Format a date as an ISO-8601 UTC timestamp with a trailing ``Z``, e.g. ``2024-01-02T03:04:05Z``.

    This differs from ``datetime.isoformat()``, which never appends ``Z`` and may include microseconds. Naive datetimes
    are assumed to already be in UTC. Timezone-aware datetimes are converted to UTC first, so that the ``Z`` suffix is
    always truthful. Sub-second precision is dropped.

    :param utc_date: the date to format
    :return: the formatted timestamp
    """
    if utc_date.utcoffset() is not None:
        utc_date = utc_date.astimezone(datetime.timezone.utc)
    return utc_date.strftime('%Y-%m-%dT%H:%M:%SZ')


async def fetch_auth_token(session: aiohttp.ClientSession, endpoint: APIEndpoint) -> Union[str, APIException]:
    """
    Fetch an authentication token from the authentication endpoint of the specified SBC.

    Failures are returned rather than raised, so that the caller can report them per endpoint.

    :param session: the HTTP session used to send the request
    :param endpoint: the SBC (and API user credentials) to authenticate with
    :return: the token on success, otherwise an APIException describing the failure
    """
    headers = {"Authorization": f"Basic {__encode_credentials(endpoint.username, endpoint.password)}"}
    url = f"https://{endpoint.host}/api/auth/v1/token"

    try:
        # The request itself sits inside the try block, so connection failures (unreachable host, TLS errors,
        #  timeouts) are returned as an APIException instead of propagating out of this coroutine.
        # NOTE(review): ssl=False disables TLS certificate verification (presumably because SBCs commonly use
        #  self-signed certificates), so the Basic credentials are not protected against a man-in-the-middle.
        async with session.post(url, headers=headers, ssl=False) as response:
            response_json = await response.json()
    except aiohttp.ContentTypeError as e:
        return APIException("Received Content-Type error, API is most likely not running.",
                            endpoint.host,
                            cause=e,
                            cause_stack=traceback.extract_stack())
    except Exception as e:
        return APIException("Caught unexpected exception while communicating with the API.",
                            endpoint.host,
                            cause=e,
                            cause_stack=traceback.extract_stack())

    # Only a JSON object can carry an error; "error" in <str> would be an accidental substring test on the token.
    if isinstance(response_json, dict) and "error" in response_json:
        return APIException("Failed to authenticate with the authentication API.",
                            endpoint.host,
                            APIErrorResponse.from_dict(response_json))

    # The token is interpolated directly into a "Bearer" header, so anything other than a JSON string is unusable.
    if not isinstance(response_json, str):
        return APIException("Received an unexpected response from the authentication API.", endpoint.host)

    return response_json


async def fetch_from_sbc(endpoint: APIEndpoint, lookback_days: int = 91) -> Union[TelemetryData, APIException]:
    """
    Fetch the collapsed telemetry data for the last ``lookback_days`` days from the specified SBC.

    A dedicated HTTP session is opened for the duration of the call, and it must be awaited inside a running event
    loop. Failures are returned rather than raised, so that callers can report them per endpoint.

    :param endpoint: the SBC to fetch the telemetry data from
    :param lookback_days: how many days before now the requested window starts (91 by default)
    :return: the parsed telemetry data, or an APIException describing why it could not be fetched
    """
    async with aiohttp.ClientSession() as session:
        token = await fetch_auth_token(session, endpoint)
        if isinstance(token, APIException):
            return token

        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
        }

        # Take a single "now" so the window is exactly lookback_days long; an aware UTC datetime also avoids the
        #  deprecated datetime.utcnow().
        now = datetime.datetime.now(datetime.timezone.utc)
        start_date = __format_iso_8601(now - datetime.timedelta(days=lookback_days))
        end_date = __format_iso_8601(now)

        url = (f"https://{endpoint.host}/api/monitoring/v1/telemetry"
               f"?collapse=true&startDate={start_date}&endDate={end_date}")

        try:
            # Connection failures are inside the try so they are reported instead of aborting the caller.
            # NOTE(review): ssl=False disables TLS certificate verification — confirm this is acceptable.
            async with session.get(url, headers=headers, ssl=False) as response:
                response_json = await response.json()
        except aiohttp.ContentTypeError as e:
            return APIException("Received Content-Type error, API is most likely not running.",
                                endpoint.host,
                                cause=e,
                                cause_stack=traceback.extract_stack())
        except Exception as e:
            return APIException("Caught unexpected exception while communicating with the API.",
                                endpoint.host,
                                cause=e,
                                cause_stack=traceback.extract_stack())

    if isinstance(response_json, dict) and "error" in response_json:
        return APIException("Failed to fetch data from telemetry API.",
                            endpoint.host,
                            APIErrorResponse.from_dict(response_json))

    # A body that does not match the TelemetryData schema is reported for this endpoint rather than raised.
    try:
        return TelemetryData.from_dict(response_json)
    except Exception as e:
        return APIException("Received an unexpected response from the telemetry API.",
                            endpoint.host,
                            cause=e,
                            cause_stack=traceback.extract_stack())


def _print_api_exception(stderr_console: Console, error: APIException) -> None:
    """
    Print an APIException returned by a fetch function, followed by its cause and captured stack when present.

    Markup is disabled so that anything resembling rich markup (``[...]``) in exception text is printed verbatim.
    """
    stderr_console.print("".join(traceback.format_exception(type(error), error, error.__traceback__)), markup=False)
    if error.cause:
        cause = error.cause
        stderr_console.print("Cause:")
        stderr_console.print("".join(traceback.format_exception(type(cause), cause, cause.__traceback__)), markup=False)
    if error.cause_stack:
        stderr_console.print("Stack Trace:")
        stderr_console.print("".join(error.cause_stack.format()), markup=False)


def _get_or_create_merged_record(merged_data: Dict[tuple, MergedTelemetryData],
                                 source: Union[UsageTelemetryData, PerformanceTelemetryData]) -> MergedTelemetryData:
    """
    Return the merged record for the (partner, customer, call type) of ``source``, creating it from ``source`` if absent.
    """
    # A tuple key cannot collide the way a "-"-joined string can when an ID itself contains "-".
    key = (source.partner_id, source.customer_id, source.call_type)
    record = merged_data.get(key)
    if record is None:
        record = MergedTelemetryData(
            summary_date=source.summary_date,
            partner_id=source.partner_id,
            customer_id=source.customer_id,
            call_type=source.call_type,
            start_date=source.start_date,
            end_date=source.end_date,
        )
        merged_data[key] = record
    return record


def _add_failure_counts(totals: Dict[int, int], failures: List[UsagePerformanceDataJoinTransferReasonData]) -> None:
    """Add the per-reason-code failure counts in ``failures`` to ``totals``, summing rather than overwriting."""
    for failure in failures:
        totals[failure.reason_code] = totals.get(failure.reason_code, 0) + failure.count


def merge_telemetry_data(stderr_console: Console, data_list: List[Union[TelemetryData, APIException]]) -> List[MergedTelemetryData]:
    """
    Merge the telemetry data fetched from every SBC into one record per (partner, customer, call type).

    APIExceptions in ``data_list`` are printed to ``stderr_console`` and skipped, so the data from the remaining SBCs
    is still merged. All counters, including the per-reason failure counts, are summed across records.

    :param stderr_console: console that error details are printed to
    :param data_list: the results of fetch_from_sbc, one per endpoint
    :return: the merged records, with their derived statistics computed by MergedTelemetryData.finish()
    """
    merged_data: Dict[tuple, MergedTelemetryData] = {}

    for data in data_list:
        if isinstance(data, APIException):
            _print_api_exception(stderr_console, data)
            continue

        for usage_data in data.usage:
            record = _get_or_create_merged_record(merged_data, usage_data)
            record.inbound_call_count += usage_data.inbound_data.call_count
            record.inbound_total_duration_minutes += usage_data.inbound_data.total_duration_in_minutes
            record.outbound_call_count += usage_data.outbound_data.call_count
            record.outbound_total_duration_minutes += usage_data.outbound_data.total_duration_in_minutes
            record.user_uri_list.update(usage_data.user_uri_list)

        for performance_data in data.performance:
            record = _get_or_create_merged_record(merged_data, performance_data)
            record.join_total_call_count += performance_data.join_data.total_call_count
            record.join_success_count += performance_data.join_data.success_count
            _add_failure_counts(record.join_failures, performance_data.join_data.failures)
            record.transfer_total_call_count += performance_data.transfer_data.total_call_count
            record.transfer_success_count += performance_data.transfer_data.success_count
            _add_failure_counts(record.transfer_failures, performance_data.transfer_data.failures)

    for record in merged_data.values():
        record.finish()

    return list(merged_data.values())


async def _fetch_all(endpoints: List[APIEndpoint]) -> List[Union[TelemetryData, APIException]]:
    """
    Fetch the telemetry data from every endpoint concurrently within a single event loop.

    Ordinary exceptions raised (rather than returned) by fetch_from_sbc are wrapped in an APIException, so that one
    failing SBC does not abort the whole run. The result order matches ``endpoints``.
    """
    results = await asyncio.gather(*(fetch_from_sbc(endpoint) for endpoint in endpoints), return_exceptions=True)

    collected: List[Union[TelemetryData, APIException]] = []
    for endpoint, result in zip(endpoints, results):
        if isinstance(result, APIException) or not isinstance(result, BaseException):
            collected.append(result)
        elif isinstance(result, Exception):
            collected.append(APIException("Caught unexpected exception while communicating with the API.",
                                          endpoint.host,
                                          cause=result))
        else:
            # Cancellation and similar BaseExceptions must not be swallowed.
            raise result
    return collected


@click.command()
@click.option("--json", is_flag=True, help="Output the telemetry data as JSON to STDOUT.")
def main(json: bool):
    """Fetch, merge and display the telemetry data of every SBC configured in TELEMETRY_ENDPOINTS."""
    console = Console()
    stderr_console = Console(stderr=True)

    if os.name == "nt":
        # NOTE(review): presumably works around aiohttp issues with the default Proactor loop on Windows — confirm.
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    data_list = asyncio.run(_fetch_all(TELEMETRY_ENDPOINTS))

    merged_data_list = merge_telemetry_data(stderr_console, data_list)
    if json:
        # Emitted with click.echo rather than the rich console, which may soft-wrap the text at the terminal width
        #  and apply highlighting; either would break the output for machine consumers.
        click.echo(json_dumps([data.to_dict() for data in merged_data_list]))
        return

    console.print("Tip: Use the --json argument to output the data in JSON format instead of tabular format.")

    if not merged_data_list:
        caption = "No telemetry reports found."
    else:
        caption = f"Found and collated {len(merged_data_list)} telemetry report records."
    table = Table(title="Telemetry Data", caption=caption)

    for column in ("Partner ID",
                   "Customer ID",
                   "Start Date",
                   "End Date",
                   "Inbound Total Call Count",
                   "Inbound Total Call Duration",
                   "Inbound Average Call Duration",
                   "Outbound Total Call Count",
                   "Outbound Total Call Duration",
                   "Outbound Average Call Duration",
                   "Number of Users",
                   "Joined Calls",
                   "Success % of Joined Calls",
                   "Transferred Calls",
                   "Success % of Transferred Calls"):
        table.add_column(column)

    for data in merged_data_list:
        table.add_row(
            data.partner_id,
            data.customer_id,
            data.start_date,
            data.end_date,
            str(data.inbound_call_count),
            str(data.inbound_total_duration_minutes),
            # Averages are rounded for display only; the JSON output keeps full precision.
            f"{data.inbound_average_call_count:.2f}",
            str(data.outbound_call_count),
            str(data.outbound_total_duration_minutes),
            f"{data.outbound_average_call_count:.2f}",
            str(data.total_unique_users),
            str(data.join_total_call_count),
            f"{round(data.join_success_pct * 100)}%",
            str(data.transfer_total_call_count),
            f"{round(data.transfer_success_pct * 100)}%",
        )

    console.print(table)


# Entry point when run as a script; click parses the command-line arguments and invokes main.
if __name__ == "__main__":
    main()
